/**
 * 
 */
package edu.berkeley.nlp.PCFGLA;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Random;

import edu.berkeley.nlp.util.ArrayUtil;
import fig.basic.Indexer;
import edu.berkeley.nlp.util.Numberer;

/**
 * @author petrov
 *
 */
public class HierarchicalLexicon extends SimpleLexicon {

	private static final long serialVersionUID = 1L;
	public List<double[]>[][] hierarchicalScores; // for each tag, word store a list of hiearchical features
	public int[][] finalLevels;
	
	/**
	 * @param numSubStates
	 * @param threshold
	 */
	public HierarchicalLexicon(short[] numSubStates, double threshold) {
		super(numSubStates, threshold);
		hierarchicalScores = new List[numStates][];
	}

	public HierarchicalLexicon(SimpleLexicon lex){
  	super(lex.numSubStates,lex.threshold);
  	this.expectedCounts = new double[numStates][][];
  	this.tagWordIndexer = new IntegerIndexer[numStates];
  	this.wordIndexer = lex.wordIndexer;
  	this.wordCounter = lex.wordCounter;
//  	this.wordIsAmbiguous = lex.wordIsAmbiguous;
  	for (int tag=0; tag<numStates; tag++){
  		this.tagWordIndexer[tag] = lex.tagWordIndexer[tag].copy();
  	}
  	this.nWords = lex.nWords;
  	this.smoother = lex.smoother;
  	makeHiearchicalScores(lex.scores);
  	this.scores = null;
  }

	// assume for now that the scores being passed in are from an unsplit baseline grammar
	private void makeHiearchicalScores(double[][][] scores) {
		hierarchicalScores = new List[numStates][];
		finalLevels = new int[numStates][];
  	for (int tag=0; tag<numStates; tag++){
  		int words = tagWordIndexer[tag].size();
  		hierarchicalScores[tag] = new List[words];
  		finalLevels[tag] = new int[words];
  		for (int word=0; word<words; word++){
  			hierarchicalScores[tag][word] = new ArrayList<double[]>();
  			double[] score = {Math.log(scores[tag][0][word])};
  			hierarchicalScores[tag][word].add(score);
  			//finalLevels[tag][word]=0; // already initialized to 0
  		}
  	}
	}

	public void explicitlyComputeScores(int finalLevel){
		this.scores = new double[numStates][][];
		int nSubstates = (int)Math.pow(2, finalLevel);
//		int[] divisors = new int[nSubstates];//finalLevel+1];
//		for (int i=0; i<=finalLevel; i++){
//			int div = (int)Math.pow(2, finalLevel-i);
//			divisors[div] = div;
//		}

		for (int tag=0; tag<numStates; tag++){
  		int words = hierarchicalScores[tag].length;
			this.scores[tag] = new double[nSubstates][words];
			for (int word=0; word<words; word++){
				List<double[]> scoreHierarchy = hierarchicalScores[tag][word];
				for (int level=0; level<=finalLevel; level++){
					if (level>finalLevels[tag][word]) 
						continue;
					double[] scoresThisLevel = scoreHierarchy.get(level); 
					int divisor = nSubstates/scoresThisLevel.length; // divisors[level];
					for (int substate=0; substate<nSubstates; substate++){
						this.scores[tag][substate][word] += scoresThisLevel[substate/divisor];
					}
				}
				for (int substate=0; substate<nSubstates; substate++){
					this.scores[tag][substate][word] = Math.exp(scores[tag][substate][word]);
				}
			}
  	}
	}
	
	public HierarchicalLexicon splitAllStates(int[] counts, boolean moreSubstatesThanCounts, int mode){
		short[] newNumSubStates = new short[numSubStates.length];
		newNumSubStates[0] = 1;
		for (short i = 1; i < numSubStates.length; i++) {
			// don't split a state into more substates than times it was actaully seen
			if (!moreSubstatesThanCounts && numSubStates[i]>=counts[i]) { 
				newNumSubStates[i]=numSubStates[i];
			} 
			else{
				newNumSubStates[i] = (short)(numSubStates[i] * 2);
			}
		}
		HierarchicalLexicon newLex = newInstance();
		newLex.numSubStates = newNumSubStates;
		Random random = GrammarTrainer.RANDOM;
		newLex.expectedCounts = new double[numStates][][];
		newLex.tagWordIndexer = new IntegerIndexer[numStates];
		newLex.wordIndexer = this.wordIndexer;
  	for (int tag=0; tag<numStates; tag++){
  		newLex.tagWordIndexer[tag] = tagWordIndexer[tag].copy();
  	}
  	newLex.nWords = this.nWords;
  	newLex.smoother = this.smoother;
		List<double[]>[][] hS = new List[numStates][];
		newLex.finalLevels = new int[numStates][];
//		int[] nSubstates = new int[finalLevel+1];
//		for (int i=0; i<=finalLevel; i++){
//			nSubstates[i] = (int)Math.pow(2, i);
//		}
  	for (int tag=0; tag<numStates; tag++){
  		int words = tagWordIndexer[tag].size();
  		hS[tag] = new List[words];
  		newLex.finalLevels[tag] = new int[words];
  		for (int word=0; word<words; word++){
  			hS[tag][word] = new ArrayList<double[]>();
  			for (double[] scores : hierarchicalScores[tag][word]){
  				hS[tag][word].add(scores.clone());
  			}
  			int fLevel = this.finalLevels[tag][word]+1;
  			int nSub = (int)Math.pow(2, fLevel);
  			if (nSub > newNumSubStates[tag]) continue;
  			double[] newScores = new double[nSub];
  			for (int i=0; i<newScores.length; i++){
  				newScores[i] = random.nextDouble()/100.0;
  			}
  			hS[tag][word].add(newScores);
  			newLex.finalLevels[tag][word] = fLevel;
  		}
  	}
  	newLex.scores = null;
  	newLex.hierarchicalScores = hS;
  	newLex.wordCounter = wordCounter;
//  	newLex.wordIsAmbiguous = wordIsAmbiguous;
  	return newLex;
	}
	
	/**
	 * @return
	 */
	public HierarchicalLexicon newInstance() {
		return new HierarchicalLexicon(this.numSubStates,this.threshold);
	}

	public int getFinalLevel(int  globalWordIndex, int tag){
		int tagSpecificWordIndex = tagWordIndexer[tag].indexOf(globalWordIndex);
		return finalLevels[tag][tagSpecificWordIndex];
	}
	
	public void mergeLexicon(){
		int nRemovedParam = 0, nRemovedArrays = 0;
		for (int tag=0; tag<numStates; tag++){
  		int words = hierarchicalScores[tag].length;
			for (int word=0; word<words; word++){
				List<double[]> scoreHierarchy = hierarchicalScores[tag][word];
				int level = finalLevels[tag][word];
				double[] scoresThisLevel = scoreHierarchy.get(level); 
				if (scoresThisLevel == null) continue;
				boolean allZero = true;
				for (int substate=0; substate<scoresThisLevel.length; substate++){
					allZero = allZero && scoresThisLevel[substate]==0;
				}
				if (allZero) {
					scoreHierarchy.remove(level);
					finalLevels[tag][word]--;
					nRemovedParam += scoresThisLevel.length;
					nRemovedArrays++;
				}
			}
  	}
		System.out.println("Removed "+nRemovedParam+" parameters in the lexicon by setting "+nRemovedArrays+" arrays to null.");
	}

	public double[] getLastLevel(int tag, int word) {
		return hierarchicalScores[tag][word].get(finalLevels[tag][word]);
	}

  public HierarchicalLexicon copyLexicon(){
  	HierarchicalLexicon copy = newInstance();
  	copy.expectedCounts = new double[numStates][][];
  	copy.scores = ArrayUtil.clone(scores);//new double[numStates][][];
  	copy.hierarchicalScores = this.hierarchicalScores;
  	copy.tagWordIndexer = new IntegerIndexer[numStates];
  	copy.wordIndexer = this.wordIndexer;
  	for (int tag=0; tag<numStates; tag++){
  		copy.tagWordIndexer[tag] = tagWordIndexer[tag].copy();
  		copy.expectedCounts[tag] = new double[numSubStates[tag]][tagWordIndexer[tag].size()];
  	}
  	if (this.wordCounter!=null) copy.wordCounter = this.wordCounter.clone();
//  	if (this.wordIsAmbiguous!=null) copy.wordIsAmbiguous = this.wordIsAmbiguous.clone();
  	copy.nWords = this.nWords;
  	copy.smoother = this.smoother;
  	if (finalLevels!=null) copy.finalLevels = ArrayUtil.clone(this.finalLevels);
  	return copy;
  }
  
  public String toString() {
  	StringBuffer sb = new StringBuffer();
  	Numberer tagNumberer = Numberer.getGlobalNumberer("tags");
  	for (int tag=0; tag<scores.length; tag++){
    	int[] counts = new int[6];
  		String tagS = (String)tagNumberer.object(tag);
  		if (tagWordIndexer[tag].size()==0) continue;
  		for (int word=0; word<scores[tag][0].length; word++){
  			sb.append(tagS+" "+ wordIndexer.get(tagWordIndexer[tag].get(word))+" ");
  			for (int sub=0; sub<numSubStates[tag]; sub++){
  				sb.append(" " + scores[tag][sub][word]);
  			}
  			for (double[] d : hierarchicalScores[tag][word]){
  				sb.append("\n"+Arrays.toString(d));
  			}
  			counts[finalLevels[tag][word]]++;
  			sb.append("\n\n");
  		}
  		System.out.print(tagNumberer.object(tag)+", word,tag pairs per level: ");
  		for (int i=1; i<6; i++){
  			System.out.print(counts[i]+" ");
  		}
  		System.out.print("\n");
  	}
  	return sb.toString();
  }

}
